/*
    Copyright 2006-2008 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

#include <sstream>
#include <iostream>
#include "CCfits/CCfits"
#include "interpolatort.h"
#include "blitz-defines.h"
#include "fits-utilities.h"
#include "blitz-fits.h"

using namespace std;
using namespace mcrx;
using namespace CCfits;
using namespace blitz;

// Interpolates a 2-D array -- one model's SED grid, indexed (time, lambda)
// as it is stored into new_SED in main -- along one axis. In this program
// the axis is log10(metallicity) (see initialize_interpolator).
// NOTE(review): the `double, 1` arguments presumably are the axis
// coordinate type and the number of axes -- confirm in interpolatort.h.
typedef interpolator<array_2, double, 1> T_SED_interpolator;
// Same as above, for the 1-D (indexed by time) log10 current-mass array.
typedef interpolator<array_1, double, 1> T_m_interpolator;


// keeps the SED and mass for a given metallicity population set (the
// quantities are arrays containing all times)
class time_model {
public:
  double metallicity;
  array_1 current_mass; // the mass fraction still present
  array_2 SED; // the SED grid of the population 
  time_model (double z, const array_1 m, const array_2& r): metallicity(z), 
							    current_mass(m),
							    SED(r) {};
  bool operator< (const time_model& rhs) const {
    return metallicity<rhs.metallicity;};
};


// initializes the SED interpolator
//
// Sets the single axis of both SEDi and mi to log10 of the metallicities
// in `models`, then stores model i's SED grid in SEDi and its
// current_mass array in mi at grid point i.
// Callers should pass the models sorted in ascending metallicity (main
// sorts them first). NOTE(review): presumably the interpolator requires a
// strictly ascending axis -- confirm in interpolatort.h.
void initialize_interpolator(T_SED_interpolator& SEDi, 
			     T_m_interpolator& mi, 
			     const vector<time_model>& models)
{
  // Convert the count once so the index stays an int (what setPoint has
  // always been given) without comparing signed against unsigned.
  const int n_models = static_cast<int>(models.size());

  // the interpolation axis is log10 of the model metallicities
  vector<double> metallicities;
  metallicities.reserve(models.size());
  for (int i = 0; i < n_models; ++i)
    metallicities.push_back(log10(models[i].metallicity));

  SEDi.setAxis(0, metallicities);
  mi.setAxis(0, metallicities);

  // attach each model's arrays to its point on that axis
  for (int i = 0; i < n_models; ++i) {
    SEDi.setPoint(i, models[i].SED);
    mi.setPoint(i, models[i].current_mass);
  }
}


int main(int argc, char** argv)
{
  if (argc<4) {
    cout << "Usage:\nassemble_models <number of interpolated metallicities> <outfile> <infiles>..." << endl;
    exit(1);
  }
  assert (argc >= 4);
  
  // First argument is the number of interpolated SEDs between the
  // supplied ones
  istringstream countstr(argv[1]);
  int n_interp=0;
  countstr >> n_interp;

  const string output_file = argv[2];

  vector<string> model_names;
  for (int i = 3; i < argc; ++i)
    model_names.push_back(argv [i] );

  cout << "Assembling " << model_names .size() 
       << " stellar model files into file " 
       << output_file << endl;

  // Load models
  vector<time_model> models;
  array_1 times;
  array_1 lambda;
  string time_unit, lambda_unit, sed_unit, mass_unit;
  for (vector<string>::const_iterator i = model_names.begin();
       i != model_names.end(); ++i) {
    FITS f (*i, Read);

    // read metallicity
    ExtHDU& model_hdu = open_HDU (f, "STELLARMODEL");
    double metallicity;
    model_hdu.readKey("track_metallicity", metallicity);
    bool log_flux;
    model_hdu.readKey("logflux",  log_flux);

    ExtHDU& spectrum = open_HDU (f, "SED");
    if(times.size()== 0) {
      // we still have not read time column, do that
      Column & time_c = spectrum.column("time" );
      read (time_c, times);
      time_unit = time_c.unit(); 
    }
    else {
      // check that times are consistent
      Column & time_c = spectrum.column("time" );
      array_1 temp;
      read (time_c, temp);
      if ((times.size() != temp.size()) || any (times != temp)) {
	cerr << "Error: SED times in file " <<*i << " are inconsistent!"  
	     << endl;
	exit (1);
      }
    }

    ExtHDU& lambda_hdu = open_HDU (f, "LAMBDA");
    if(lambda.size()== 0) {
      // we still have not read lambda column, do that
      Column & lambda_c = lambda_hdu.column("lambda" );
      read (lambda_c, lambda);
      lambda_unit = lambda_c.  unit ();
    }
    else {
      // check that times are consistent
      Column & lambda_c = lambda_hdu.column("lambda" );
      array_1 temp;
      read (lambda_c, temp);
      if ((lambda.size() != temp.size()) || any (lambda != temp)) {
	cerr << "Error: SED wavelengths in file " <<*i << " are inconsistent!"  
	     << endl;
	exit (1);
      }
    }

    // read SED
    Column & sed_c =
      spectrum.column((log_flux?  string ("log "): string ("")) + "L_lambda");
    array_2 sed;
    read (sed_c, sed);
    // make sed unconditionally logarithmic
    if(!log_flux)
      sed=log10(sed);
    sed_unit = sed_c.unit(); 

    // read mass fraction
    Column & mass_c =
      spectrum.column("current_mass");
    array_1 cm;
    read (mass_c, cm);
    // take log because we interpolate in log space
    cm = log10(cm);
    mass_unit = mass_c.unit();

    models.push_back( time_model(metallicity, cm, sed) );
  }
  
  // sort models in ascending metallicity
  sort(models.begin(), models.end());
  
  // Initialize the interpolator
  T_SED_interpolator si;
  T_m_interpolator mi;
  initialize_interpolator(si, mi, models);

  // generate metallicity vector to be interpolated into
  vector<double> new_metallicities;
  for (int i = 0; i < models.size()-1; ++i) {
    const double l = models [i].metallicity;
    const double h = models [i+ 1].metallicity;
    new_metallicities.push_back(l );
    for (int j = 1; j <= n_interp; ++j) {
      const double Z = pow (10., j*(log10(h) - log10(l))/(n_interp+ 1) + 
			    log10(l));
      new_metallicities.push_back(Z );
    }
  }
  new_metallicities.push_back(models.back().metallicity );

  // and interpolate
  array_3 new_SED (times.size(), new_metallicities.size(), lambda.size());
  array_2 new_mass (times.size(), new_metallicities.size());
  for (int i = 0; i < new_metallicities.size(); ++i) {
    cout << "Interpolating SED for metallicity " << new_metallicities [i] 
	 << endl;
    const array_2 tempsed(si.interpolate(log10 (new_metallicities [i]) ));
    const array_1 tempm(pow(10,mi.interpolate(log10 (new_metallicities [i]))));
    assert(all(tempm>=0));
    assert(all(tempm<=1));
    new_SED (Range::all (), i, Range::all ()) = tempsed;
    new_mass (Range::all (), i) = tempm;
  }

  // dump output
  /*
  ofstream dump("assdump");
  for (int i = 0; i < new_SED.extent(firstDim); ++i)
    for (int j = 0; j < new_SED.extent(secondDim); ++j) {
      for (int k = 0; k < new_SED.extent(thirdDim); ++k)
	dump << times(i) << '\t' << new_metallicities [j] << '\t' << lambda (k)
	     << '\t' << new_SED (i, j, k) << '\n';
      dump << "\n\n";
    }
  */

  // Write output file
  cout << "Writing output file " << output_file << endl;
  FITS output ("!"+output_file, Write);
  output.pHDU().addKey("FILETYPE", string ("STELLARMODEL"),
		       "This file contains the SED of a stellar model ");

  FITS input (model_names.front(), Read);
  output.copy(open_HDU(input, "STELLARMODEL" ) );
  ExtHDU& sm = open_HDU (output, "STELLARMODEL");
  // remove keywords and that don't make sense in this file 
  try {
  sm.deleteKey ("track_metallicity");
  } catch (...) {}
  try {sm.deleteKey ("ewidth_file");} catch (...) {}
  try {sm.deleteKey ("OUTFILE");} catch (...) {}
  try {sm.deleteKey ("model_designation");} catch (...) {}
  try {sm.deleteKey ("SEDFILE");} catch (...) {}
  sm.writeHistory("This file was assembled from these single-metallicity stellar model files by assemble_models:");
  for (int i = 0; i < model_names.size(); ++i)
    sm.writeHistory(model_names [i]);
  sm.addKey("ninterp", n_interp,
	    "[] Number of interpolated SEDs in metallicity");
  
  vector<long> naxes;
  naxes.push_back(new_SED.extent(firstDim ) );
  naxes.push_back(new_SED.extent(secondDim ) );
  naxes.push_back(new_SED.extent(thirdDim ) );
  ExtHDU* sed_HDU = output.addImage("SED", DOUBLE_IMG, naxes);
  write(*sed_HDU, new_SED);
  sed_HDU->addKey("UNIT", sed_unit,"");

  naxes.resize(0);
  naxes.push_back(new_mass.extent(firstDim ) );
  naxes.push_back(new_mass.extent(secondDim ) );
  ExtHDU* m_HDU = output.addImage("current_mass", DOUBLE_IMG, naxes);
  write(*m_HDU, new_mass);
  m_HDU->addKey("UNIT", mass_unit,"");

  ExtHDU*axes_HDU = output.addTable("AXES", 0);
  axes_HDU->addKey("n_time", times.size(), "");
  axes_HDU->addColumn(Tdouble, "time", 1, time_unit );
  write (axes_HDU->column("time" ), times, 1);
  axes_HDU->addKey("n_metallicity", new_metallicities.size(), "");
  axes_HDU->addColumn(Tdouble, "metallicity", 1, "" );
  axes_HDU->column("metallicity" ).write(new_metallicities, 1);
  axes_HDU->addKey("n_lambda", lambda.size(), "");
  axes_HDU->addColumn(Tdouble, "lambda", 1, lambda_unit );
  write (axes_HDU->column("lambda" ), lambda, 1);

}
